{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "01990ab9",
   "metadata": {},
   "source": [
    "# Implementing recurrent networks\n",
    "\n",
    "Unlike convolutional networks (which required us to implement a new kind of operation), recurrent networks in theory are quite straightforward to implement: although the particular details of the \"cell\" for a more complex recurrent network like an LSTM seem a bit complex, it is ultimately just a collection of operators that are fairly easy to chain together in a automatic differentiation tool.  However, there _are_ a number of considerations to keep in mind when implementing recurrent networks efficiently, most steming from the fact that they are fundamentally _sequential_ models.  This means that, unlike the \"normal\" deep network, where we think of all operations being \"easily\" parallelizable over a batch of data, the data input to an LSTM need to be fed in one at a time: we cannot process the second element in a sequence until we process the first, etc."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bcf7863",
   "metadata": {},
   "source": [
    "## Implementing the LSTM cell\n",
    "\n",
    "One thing we pointed out in the notes, which I want to highlight here, is that matrices behind the LSTM cell are not as complex as they look.  Let's look at a typical equation for an LSTM cell as it's usually written in document or papers, in this case from the PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html.\n",
    "\n",
    "\\begin{array}{ll} \\\\\n",
    "        i_t = \\sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\\\\n",
    "        f_t = \\sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\\\\n",
    "        g_t = \\tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\\\\n",
    "        o_t = \\sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\\\\n",
    "        c_t = f_t \\odot c_{t-1} + i_t \\odot g_t \\\\\n",
    "        h_t = o_t \\odot \\tanh(c_t) \\\\\n",
    "    \\end{array}\n",
    "\n",
    "Those equations look awful confusing (plus they have the wrong indices in the subscripts for the $W$ terms).  But the first things to realize, which I emphasized in the notes, is that there aren't actually 8 different weights: there are two different weight.  You should think of the vector\n",
    "\\begin{bmatrix}\n",
    "i \\\\ f \\\\ g \\\\ o\n",
    "\\end{bmatrix}\n",
    "not four separate vectors, but as a single vector that just happens to be four times the length of the hidden unit, i.e., we really have the update\n",
    "\\begin{equation}\n",
    "\\begin{bmatrix}\n",
    "i \\\\ f \\\\ g \\\\ o\n",
    "\\end{bmatrix}\n",
    "= \\begin{pmatrix} \\sigma \\\\ \\sigma \\\\ \\tanh \\\\ \\sigma \\end{pmatrix} (W_{hi} x + W_{hh} h + b)\n",
    "\\end{equation}\n",
    "(where I'm unilaterally deciding to get rid of the odd choice to have two different bias terms, and relabeling the subscripts on $W$ properly).\n",
    "\n",
    "This is how PyTorch does it internally anyway, as you can see if you inspect the actual class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "aec48f9e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([400, 100])\n",
      "torch.Size([400, 20])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "model = nn.LSTMCell(20,100)\n",
    "print(model.weight_hh.shape)\n",
    "print(model.weight_ih.shape)"
   ]
  },
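  {
   "cell_type": "markdown",
   "id": "c4e1a7d2",
   "metadata": {},
   "source": [
    "As a quick sanity check of the \"two weights, not eight\" claim, here is a minimal NumPy sketch (the per-gate matrices `W_x` and `W_h` are hypothetical random stand-ins, not PyTorch parameters).  It suggests that stacking the four per-gate matrices and splitting the result gives the same pre-activations as computing each gate separately:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n_in, n_hid = 20, 100\n",
    "x, h = rng.normal(size=n_in), rng.normal(size=n_hid)\n",
    "\n",
    "# Four separate (hypothetical) per-gate weight pairs, in the order i, f, g, o.\n",
    "W_x = [rng.normal(size=(n_hid, n_in)) for _ in range(4)]\n",
    "W_h = [rng.normal(size=(n_hid, n_hid)) for _ in range(4)]\n",
    "separate = [Wx @ x + Wh @ h for Wx, Wh in zip(W_x, W_h)]\n",
    "\n",
    "# The same computation with the gate matrices stacked along the output dimension.\n",
    "W_hi = np.concatenate(W_x, axis=0)   # shape (4*n_hid, n_in)\n",
    "W_hh = np.concatenate(W_h, axis=0)   # shape (4*n_hid, n_hid)\n",
    "stacked = np.split(W_hi @ x + W_hh @ h, 4)\n",
    "\n",
    "print(W_hi.shape, W_hh.shape)\n",
    "# Largest elementwise gap between the two ways of computing the gates.\n",
    "print(max(np.abs(s - t).max() for s, t in zip(separate, stacked)))\n",
    "```\n",
    "\n",
    "The stacked shapes here are the same `(400, 20)` and `(400, 100)` sizes that `nn.LSTMCell(20, 100)` reports for its two weight matrices."
   ]
  },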
  {
   "cell_type": "markdown",
   "id": "df14e72e",
   "metadata": {},
   "source": [
    "We could define our own LSTM cell using something like the following function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "e74648d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def sigmoid(x):\n",
    "    return 1/(1+np.exp(-x))\n",
    "\n",
    "def lstm_cell(x, h, c, W_hh, W_ih, b):\n",
    "    i,f,g,o = np.split(W_ih@x + W_hh@h + b, 4)\n",
    "    i,f,g,o = sigmoid(i), sigmoid(f), np.tanh(g), sigmoid(o)\n",
    "    c_out = f*c + i*g, \n",
    "    h_out = o * np.tanh(c_out)\n",
    "    return h_out, c_out"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "331b14fc",
   "metadata": {},
   "source": [
    "Let's confirm that this gives the same results as PyTorch's version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "4daadf59",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = np.random.randn(1,20).astype(np.float32)\n",
    "h0 = np.random.randn(1,100).astype(np.float32)\n",
    "c0 = np.random.randn(1,100).astype(np.float32)\n",
    "\n",
    "h_, c_ = model(torch.tensor(x), (torch.tensor(h0), torch.tensor(c0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "d2fb74ff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.7560379e-07 3.7099852e-07\n"
     ]
    }
   ],
   "source": [
    "h, c = lstm_cell(x[0], h0[0], c0[0], \n",
    "                 model.weight_hh.detach().numpy(), \n",
    "                 model.weight_ih.detach().numpy(), \n",
    "                 (model.bias_hh + model.bias_ih).detach().numpy())\n",
    "\n",
    "print(np.linalg.norm(h_.detach().numpy() - h), \n",
    "      np.linalg.norm(c_.detach().numpy() - c))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "587ec9db",
   "metadata": {},
   "source": [
    "### Iterating over a sequence\n",
    "\n",
    "We can run the function on a whole sequence simply by iterating over this cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "34e81cf8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([400, 100])\n",
      "torch.Size([400, 20])\n"
     ]
    }
   ],
   "source": [
    "model = nn.LSTM(20, 100, num_layers = 1)\n",
    "print(model.weight_hh_l0.shape)\n",
    "print(model.weight_ih_l0.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "d88e6f85",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = np.random.randn(50,20).astype(np.float32)\n",
    "h0 = np.random.randn(1,100).astype(np.float32)\n",
    "c0 = np.random.randn(1,100).astype(np.float32)\n",
    "H_, (hn_, cn_) = model(torch.tensor(X)[:,None,:], \n",
    "                       (torch.tensor(h0)[:,None,:], \n",
    "                        torch.tensor(c0)[:,None,:]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "7df03842",
   "metadata": {},
   "outputs": [],
   "source": [
    "def lstm(X, h, c, W_hh, W_ih, b):\n",
    "    H = np.zeros((X.shape[0], h.shape[0]))\n",
    "    for t in range(X.shape[0]):\n",
    "        h, c = lstm_cell(X[t], h, c, W_hh, W_ih, b)\n",
    "        H[t,:] = h\n",
    "    return H, c\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "1f75eabf",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 1 is different from 100)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[39], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m H, cn \u001b[38;5;241m=\u001b[39m \u001b[43mlstm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mh0\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mc0\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m      2\u001b[0m \u001b[43m             \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight_hh_l0\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdetach\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnumpy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m      3\u001b[0m \u001b[43m             \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight_ih_l0\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdetach\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnumpy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m      4\u001b[0m \u001b[43m             \u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias_hh_l0\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias_ih_l0\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdetach\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnumpy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[29], line 4\u001b[0m, in \u001b[0;36mlstm\u001b[0;34m(X, h, c, W_hh, W_ih, b)\u001b[0m\n\u001b[1;32m      2\u001b[0m H \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mzeros((X\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m], h\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]))\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m t \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(X\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]):\n\u001b[0;32m----> 4\u001b[0m     h, c \u001b[38;5;241m=\u001b[39m \u001b[43mlstm_cell\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m[\u001b[49m\u001b[43mt\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mh\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mW_hh\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mW_ih\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mb\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m      5\u001b[0m     H[t,:] \u001b[38;5;241m=\u001b[39m h\n\u001b[1;32m      6\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m H, c\n",
      "Cell \u001b[0;32mIn[36], line 7\u001b[0m, in \u001b[0;36mlstm_cell\u001b[0;34m(x, h, c, W_hh, W_ih, b)\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mlstm_cell\u001b[39m(x, h, c, W_hh, W_ih, b):\n\u001b[0;32m----> 7\u001b[0m     i,f,g,o \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39msplit(W_ih\u001b[38;5;129m@x\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[43mW_hh\u001b[49m\u001b[38;5;129;43m@h\u001b[39;49m \u001b[38;5;241m+\u001b[39m b, \u001b[38;5;241m4\u001b[39m)\n\u001b[1;32m      8\u001b[0m     i,f,g,o \u001b[38;5;241m=\u001b[39m sigmoid(i), sigmoid(f), np\u001b[38;5;241m.\u001b[39mtanh(g), sigmoid(o)\n\u001b[1;32m      9\u001b[0m     c_out \u001b[38;5;241m=\u001b[39m f\u001b[38;5;241m*\u001b[39mc \u001b[38;5;241m+\u001b[39m i\u001b[38;5;241m*\u001b[39mg, \n",
      "\u001b[0;31mValueError\u001b[0m: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 1 is different from 100)"
     ]
    }
   ],
   "source": [
    "H, cn = lstm(X, h0[0], c0[0], \n",
    "             model.weight_hh_l0.detach().numpy(), \n",
    "             model.weight_ih_l0.detach().numpy(), \n",
    "             (model.bias_hh_l0 + model.bias_ih_l0).detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "8c085b1c",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'H' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[38], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28mprint\u001b[39m(np\u001b[38;5;241m.\u001b[39mlinalg\u001b[38;5;241m.\u001b[39mnorm(\u001b[43mH\u001b[49m \u001b[38;5;241m-\u001b[39m H_[:,\u001b[38;5;241m0\u001b[39m,:]\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mnumpy()),\n\u001b[1;32m      2\u001b[0m       np\u001b[38;5;241m.\u001b[39mlinalg\u001b[38;5;241m.\u001b[39mnorm(cn \u001b[38;5;241m-\u001b[39m cn_[\u001b[38;5;241m0\u001b[39m,\u001b[38;5;241m0\u001b[39m,:]\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mnumpy()))\n",
      "\u001b[0;31mNameError\u001b[0m: name 'H' is not defined"
     ]
    }
   ],
   "source": [
    "print(np.linalg.norm(H - H_[:,0,:].detach().numpy()),\n",
    "      np.linalg.norm(cn - cn_[0,0,:].detach().numpy()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e3cbf20",
   "metadata": {},
   "source": [
    "## Batching efficiently\n",
    "\n",
    "Everything works above, but if you looked a bit closer at the example above, you might have noticed some pretty odd tensor sizes.  We had to pass a tensor of size `(50,1,20)` to the PyTorch LSTM to get a similar result as for our own (more intuitively-sized?) function.  What's happening here?\n",
    "\n",
    "The basic issue is that, as mentioned above, the LSTM as we implemented it in our own function is inherently sequential: we need to run the `lstm_cell` call for time t=1 before we move to time t=2, before time t=3, etc, because the hidden unit computed at time t=1 becomes is input to time t=2.  But each `lstm_cell` call operation as performed above is fundamentally a matrix-vector operation, and as we have seen many times before, we'd ideally like to turn this into a matrix-matrix operation so as to perform more efficient compuation.\n",
    "\n",
    "So, just like for the case of the MLP, for example, we're going to employ mini-batches to achieve this computational efficiency.  But the key point is that the examples in these minibatches cannot be from the same sequence: they need to be from multiple _different_ sequences (or often, in practice, from locations far apart in a single sequence).\n",
    "\n",
    "### The form of batched samples for LSTMs\n",
    "\n",
    "Once we move to minibactches of samples, each with a number of timesteps (we'll assume each sample has the same number of timesteps for now, but this will be dealt with shortly), then we need to store input that has a batch, time, and input dimension.  Most naturally, this would look something like the following:\n",
    "\n",
    "    X[NUM_BATCHES][NUM_TIMESTEPS][INPUT_SIZE]\n",
    "    \n",
    "which we could call NTC format (this isn't all that common for LSTMs, but it's analogous to the NHWC format we discussed for images).  However, PyTorch natively uses the \"TNC\" format for LSTMs, that is, it stores the tensor in the order:\n",
    "\n",
    "    X[NUM_TIMESTEPS][NUM_BATCHES][INPUT_SIZE]\n",
    "    \n",
    "Why does it do this?  The batch dimension is _always_ first in any other setting in PyTorch, and indeed there's even now an option to use this \"batch first\" format for LSTMs, though it isn't the default.\n",
    "\n",
    "The reason is due to memory locality.  In order to effectively batch the operations of an LSTM into matrix-matrix product form, we want to perform the matrix multiplications\n",
    "\n",
    "\\begin{equation}\n",
    "\\begin{bmatrix}\n",
    "I \\\\ F \\\\ G \\\\ O\n",
    "\\end{bmatrix}\n",
    "= \\begin{pmatrix} \\sigma \\\\ \\sigma \\\\ \\tanh \\\\ \\sigma \\end{pmatrix} (X W_{hi} + H W_{hh} + b)\n",
    "\\end{equation}\n",
    "\n",
    "where we're considering $X$ here to be an $N \\times C$ matrix, and $H,I,F,G,O$ to be $N \\times K$ where $K$ is the hidden dimension (we're also implicitly transposing the matrices from their forms above ... this is why PyTorch lists them as $W_{ih}$ in the docs, even though that's not really correct for the vector form it's showing).\n",
    "\n",
    "But in order to have each $X$ and $H$ (over all batches, for a single timestep) be contiguous in memory, we need to use the THC ordering and then select `X[t]` and `H[t]` as the relevant indices.  Let's see how this looks in an implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "id": "f510b5bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def lstm_cell(x, h, c, W_hh, W_ih, b):\n",
    "    i,f,g,o = np.split(x@W_ih + h@W_hh + b[None,:], 4, axis=1)\n",
    "    i,f,g,o = sigmoid(i), sigmoid(f), np.tanh(g), sigmoid(o)\n",
    "    c_out = f*c + i*g\n",
    "    h_out = o * np.tanh(c_out)\n",
    "    return h_out, c_out\n",
    "\n",
    "def lstm(X, h, c, W_hh, W_ih, b):\n",
    "    H = np.zeros((X.shape[0], X.shape[1], h.shape[1]))\n",
    "    for t in range(X.shape[0]):\n",
    "        h, c = lstm_cell(X[t], h, c, W_hh, W_ih, b)\n",
    "        H[t,:,:] = h\n",
    "    return H, c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "id": "e1987132",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = np.random.randn(50,80,20).astype(np.float32)\n",
    "h0 = np.random.randn(80,100).astype(np.float32)\n",
    "c0 = np.random.randn(80,100).astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "id": "6bd29623",
   "metadata": {},
   "outputs": [],
   "source": [
    "H_, (hn_, cn_) = model(torch.tensor(X), \n",
    "                       (torch.tensor(h0)[None,:,:], \n",
    "                        torch.tensor(c0)[None,:,:]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "id": "530fc579",
   "metadata": {},
   "outputs": [],
   "source": [
    "H, cn = lstm(X, h0, c0,\n",
    "             model.weight_hh_l0.detach().numpy().T, \n",
    "             model.weight_ih_l0.detach().numpy().T, \n",
    "             (model.bias_hh_l0 + model.bias_ih_l0).detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "id": "0c73f8e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5.435998514003646e-06 1.3124086e-06\n"
     ]
    }
   ],
   "source": [
    "print(np.linalg.norm(H - H_.detach().numpy()),\n",
    "      np.linalg.norm(cn - cn_[0].detach().numpy()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2027c5b3",
   "metadata": {},
   "source": [
    "If we _were_ to store the matrices in NTC ordering, and e.g., using the matrix multiplication we will consider in  Homework 3, where matrices need to be compact in memory before performing multiplication, we would have to be copying memory around during each update.  The TNC format fixes this (and even if we _were_ to develop a more efficient multiplication strategy that could directly consider strided matrices, the NTC format would still sacrifice memory locality).  PyTorch (and needle, as you will implement on Homework 4) will thus the TNC route, and sacrifice a few people being confused the first time they use LSTMs."
   ]
  },
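  {
   "cell_type": "markdown",
   "id": "b7c41e90",
   "metadata": {},
   "source": [
    "As a rough illustration of the memory-locality argument, here is a minimal sketch using NumPy's own array flags (this is not how needle or PyTorch check layouts internally, and the names `X_tnc` and `X_ntc` exist only for this example).  It checks which layout gives a contiguous per-timestep slice:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "T, N, C = 50, 80, 20\n",
    "X_tnc = np.zeros((T, N, C), dtype=np.float32)  # time-major layout\n",
    "X_ntc = np.zeros((N, T, C), dtype=np.float32)  # batch-major layout\n",
    "\n",
    "t = 3\n",
    "step_tnc = X_tnc[t]      # (N, C) slice for one timestep\n",
    "step_ntc = X_ntc[:, t]   # (N, C) slice for the same timestep\n",
    "\n",
    "# In TNC the slice is one dense block; in NTC its rows are T*C elements apart.\n",
    "print(step_tnc.flags['C_CONTIGUOUS'], step_tnc.strides)\n",
    "print(step_ntc.flags['C_CONTIGUOUS'], step_ntc.strides)\n",
    "\n",
    "# Making the NTC slice compact requires an explicit copy at every timestep.\n",
    "step_ntc_compact = np.ascontiguousarray(step_ntc)\n",
    "print(np.shares_memory(step_ntc, step_ntc_compact))\n",
    "```\n",
    "\n",
    "The copy in the last step is exactly the per-update memory traffic that the TNC layout avoids."
   ]
  },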
  {
   "cell_type": "markdown",
   "id": "fa1b49c2",
   "metadata": {},
   "source": [
    "### Packed sequences\n",
    "\n",
    "There is still one substantial problem with the approach above: the tensor form requires that each sequence in the batch be the same size.  This is often explicitly not the case for RNNs, where e.g., one may want to use an RNN to process individual sentences of text, individual audio signals, etc.  A large _benefit_ of the RNN for these settings is that the sequences can all be different lengths, yet an RNN can process them all similarly.  Thus, we have a common setting where the different sequences in a minibatch might be different lengths.\n",
    "\n",
    "One \"simple\" way to deal with this is simply to zero-pad the input sequences to the size of the longest sequence.  That is, we can place all the sequences in a single `X[MAX_TIMESTEPS][BATCH][DIMENION]`, replace all inputs that occur after the end of each sequence with zeros, and then after-the-fact extract the hidden unit representation at the effective end of each sequence (or at all valid points in the sequence).  Since this takes advantage of \"full, equal sized\" matrix multiplications at each step, this is reasonable solution for it's simplicity).\n",
    "\n",
    "However, if the sequences in a batch are _very_ different sizes, it should be acknowledged that this can ultimately be inefficient to run all the operations of the LSTM on what amount to a lot of meaningless data full of just zeros.  To get around this, an alternative is to support \"packed sequences\".  This represents the input as a 2D tensor\n",
    "    \n",
    "    X[sum_batch TIMESTEPS(batch)][DIMENSION]\n",
    "\n",
    "that lumps together elements in both the batch and time dimensions.\n",
    "\n",
    "In order to still exploit contiguous memory, we still want to group together elements by timestep, so this format contains first all inputs for all sequences at time 1 (they will all exist here across all samples), followed by all inputs for all sequences at time 2 (only for those that actually exist), etc.  Then, in addition, there needs to be a \n",
    "\n",
    "    int time_indexes[MAX_TIMESTEPS]\n",
    "    \n",
    "variable that points to the starting index of the batch for each timestep.\n",
    "\n",
    "We won't include the code to do this here, as it involves a bit more bookkeeping to keep everything in place, but it should be apparent that for the cost of a bit more additional indexing, you can run the LSTM on _only_ those portions of a sequence that actually exist.  Whether or not this is ultimately beneficial depends on how much you're able to saturate the compute hardware at the later \"sparser\" stages of the processing."
   ]
  },
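  {
   "cell_type": "markdown",
   "id": "d93f5a16",
   "metadata": {},
   "source": [
    "Although the full packed LSTM is left out, the packing step itself is short.  The following is a minimal sketch, assuming the sequences are sorted longest first so that the sequences still alive at time `t` are always the first few entries of the batch.  The helper `pack_sequences` and the `batch_sizes` array are names introduced for this example; `time_indexes` follows the description above.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def pack_sequences(seqs):\n",
    "    # Sort by length, longest first.\n",
    "    order = sorted(range(len(seqs)), key=lambda b: -seqs[b].shape[0])\n",
    "    lengths = np.array([seqs[b].shape[0] for b in order])\n",
    "    # batch_sizes[t] = number of sequences that still have an input at time t\n",
    "    batch_sizes = np.array([(lengths > t).sum() for t in range(lengths[0])])\n",
    "    # time_indexes[t] = row of the packed array where timestep t starts\n",
    "    time_indexes = np.concatenate([[0], np.cumsum(batch_sizes)[:-1]])\n",
    "    packed = np.concatenate(\n",
    "        [np.stack([seqs[b][t] for b in order[:n]]) for t, n in enumerate(batch_sizes)])\n",
    "    return packed, time_indexes, batch_sizes, order\n",
    "\n",
    "seqs = [np.full((2, 3), 10.), np.full((4, 3), 20.), np.full((3, 3), 30.)]\n",
    "packed, time_indexes, batch_sizes, order = pack_sequences(seqs)\n",
    "print(packed.shape, time_indexes, batch_sizes)\n",
    "\n",
    "for t, (start, n) in enumerate(zip(time_indexes, batch_sizes)):\n",
    "    x_t = packed[start:start + n]   # contiguous (n, DIMENSION) block for timestep t\n",
    "    print(t, x_t.shape)\n",
    "```\n",
    "\n",
    "A packed LSTM would then update only the first `batch_sizes[t]` rows of the hidden and cell states at step `t`, leaving the rows of already-finished sequences untouched."
   ]
  },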
  {
   "cell_type": "markdown",
   "id": "64339e07",
   "metadata": {},
   "source": [
    "## Training LSTMs: Truncated BPTT and hidden repackaging\n",
    "\n",
    "Now that we have covered the basics of LSTM creation, we're going to briefly mention how to train these systems in practice.  As this involves actually running the LSTM code in a autodiff tool, we're going to instead just include pseudo-code for these ideas here, but you'll implement them within needle for the final homework.\n",
    "\n",
    "First, we should emphasize one main point, that the majority of implementing training in a recurrent neural network \"just\" involves running the RNN under an automatic differentiation tool, and using autodiff to find the gradient of all parameters with respect to some loss.  That is, using roughly the notation from the above (i.e., with the `lstm` call that run over an entire (batch of) sequences, we could summarize the training procedure as the following:\n",
    "\n",
    "```python\n",
    "def train_lstm(X, Y, h0, c0, parameters)\n",
    "    H, cn = lstm(X, h0, c0, parameters)\n",
    "    l = loss(H, Y)\n",
    "    l.backward()\n",
    "    opt.step()\n",
    "```\n",
    "    \n",
    "For a multi-layer LSTM, we actually have some choice in determinining in what \"order\" we run it: do we run a full layer first (over all time), and then iterate over depth?  Or do we run all depth first for a single time step, and then iterate over time?  While both are options, it's conceptually simpler (because we can re-use the same function above) to follow the first option, i.e., to just run each LSTM over the full sequence and then iterate over layers, i.e.,\n",
    "\n",
    "```python\n",
    "def train_lstm(X, Y, h0, c0, parameters)\n",
    "    H = X\n",
    "    for i in range(depth):\n",
    "        H, cn = lstm(H, h0[i], c0[i], parameters[i])\n",
    "    l = loss(H, Y)\n",
    "    l.backward()\n",
    "    opt.step()\n",
    "```\n",
    "\n",
    "This training process (for both single and multi-layer cases) is known as \"backpropagation through time\" (BPTT), as we're essentially doing backprop (but automatically via the autodiff framework) over each time step in the sequence."
   ]
  },
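  {
   "cell_type": "markdown",
   "id": "e2a86b4f",
   "metadata": {},
   "source": [
    "To make the layer-by-layer ordering concrete, here is a minimal forward-only NumPy sketch.  It reuses the batched TNC conventions from earlier; `multilayer_lstm`, the random weights, and the sizes below are assumptions for illustration, and there is no autodiff here, so this covers only the forward pass of the pseudo-code above.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def sigmoid(x):\n",
    "    return 1 / (1 + np.exp(-x))\n",
    "\n",
    "def lstm(X, h, c, W_hh, W_ih, b):\n",
    "    # X: (T, N, C_in); h, c: (N, K); W_ih: (C_in, 4K); W_hh: (K, 4K); b: (4K,)\n",
    "    H = np.zeros((X.shape[0], X.shape[1], h.shape[1]))\n",
    "    for t in range(X.shape[0]):\n",
    "        i, f, g, o = np.split(X[t] @ W_ih + h @ W_hh + b[None, :], 4, axis=1)\n",
    "        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)\n",
    "        h = sigmoid(o) * np.tanh(c)\n",
    "        H[t] = h\n",
    "    return H, c\n",
    "\n",
    "def multilayer_lstm(X, h0, c0, params):\n",
    "    # Each layer consumes the full output sequence of the layer below it.\n",
    "    H = X\n",
    "    for (W_hh, W_ih, b), h, c in zip(params, h0, c0):\n",
    "        H, _ = lstm(H, h, c, W_hh, W_ih, b)\n",
    "    return H\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "T, N, C, K, depth = 6, 4, 5, 8, 3\n",
    "# Only the first layer sees inputs of size C; deeper layers see size K.\n",
    "params = [(rng.normal(size=(K, 4 * K)),\n",
    "           rng.normal(size=(C if layer == 0 else K, 4 * K)),\n",
    "           np.zeros(4 * K)) for layer in range(depth)]\n",
    "h0 = np.zeros((depth, N, K))\n",
    "c0 = np.zeros((depth, N, K))\n",
    "H = multilayer_lstm(rng.normal(size=(T, N, C)), h0, c0, params)\n",
    "print(H.shape)\n",
    "```\n",
    "\n",
    "Because every layer sees the whole sequence from the layer below, the same single-layer `lstm` function is reused unchanged at each depth."
   ]
  },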
  {
   "cell_type": "markdown",
   "id": "f2ee4c34",
   "metadata": {},
   "source": [
    "### Long sequences and truncated BPTT\n",
    "\n",
    "The process above works fine conceptually, but what happens in the case that we have a very long sequence?  For instance, in many language modeling tasks the true underlying sequence could be a document with thousands of words or audio signals that span many minutes.  Trying to train all of these in a \"single\" pass of BPTT would be:\n",
    "    \n",
    "1. Computationally/memory-wise infeasible (because we have to keep the whole computation graph in memory over the entire sequence).\n",
    "2. Inefficient for learning.  Just like batch gradient descent is inefficient from a learning standpoint relative to SGD, taking a single gradient step for an entire sequence is very inefficient from a learning perspective: we have to do a ton of work to get a single parameter update.\n",
    "    \n",
    "Thus, the simple solution is just to divide the sequence in multiple shorter blocks.  That is, we train the the LSTM on segments of the full sequence.  This could look something like the following.\n",
    "\n",
    "```python\n",
    "for i in range(0,X.shape[0],BLOCK_SIZE):\n",
    "    h0, c0 = zeros()\n",
    "    train_lstm(X[i:i+BLOCK_SIZE], Y[i:i+BLOCK_SIZE], \n",
    "               h0, c0, parameters)\n",
    "```\n",
    "\n",
    "This works, and \"solves\" the problem of long sequence lengths. But it is also unsatisfying: we got rid of long sequences by just chopping them into shorter sequences.  And this ignores the fact that it is precisely the long term dependencies (beyond `BLOCK_SIZE`) that are often most interesting in sequence models, i.e., language models that \"remember\" the general context of the words they are generating, etc."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ef969c4",
   "metadata": {},
   "source": [
    "### Hidden repackaging\n",
    "\n",
    "The way around this is to use what is called \"hidden repackaging\".  At the end of running an LSTM on a sequence, we have the final hidden units (hidden and cell units) of the LSTM.  These embed the \"current\" state of the system, in a way.  We can't continue to continue to differentiate through these variables in the autodiff graph, but we _can_ input their raw values as the $h_0$ and $c_0$ variables into the LSTM run on the next chunk of data.  To do this, we'd want to adjust our LSTM training code to return these variables, but detached from their gradients.\n",
    "\n",
    "```python\n",
    "def train_lstm(X, Y, h0, c0, parameters)\n",
    "    H, cn = lstm(X, h0, c0, parameters)\n",
    "    l = loss(H, Y)\n",
    "    l.backward()\n",
    "    opt.step()\n",
    "    return H[-1].data, cn.data\n",
    "```\n",
    "\n",
    "We then use these values (instead of zeros), as the initial state of the LSTM in subsequent training loops.\n",
    "\n",
    "```python\n",
    "h0, c0 = zeros()\n",
    "for i in range(0,X.shape[0],BLOCK_SIZE):\n",
    "    h0, c0 = train_lstm(X[i:i+BLOCK_SIZE], Y[i:i+BLOCK_SIZE], h0, c0, parameters)\n",
    "```\n",
    "\n",
    "It's important to emphasize that this process is _still_ running truncated BPTT, as we're only computing gradients through a small portion of the full sequence.  But it's somewhat of a \"middle ground\" between doing full BPTT and always re-initializing the initial hidden states to zeros: the future LSTM states can get information from longer term context from the LSTM, and can use this to make its predictions of the future.  But it _cannot_ assess how changing the parameters of the LSTM would have changed this past initial state."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
